
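% function takes input trajectories and filters them based on the
% difference between the baseline and stimulus windows
%
% inputs:
% y: the data array, structured where each row is a trajectory and each
% column is a timepoint
% baselineframes: the frames (columns of y) to use as the baseline
% stimframes: the frames (columns of y) to use as the stimulus
% thresholds: 3- or 4-element array with the thresholds for the z-score
% filtering, power filtering, and anova filtering respectively; an optional
% 4th element is the z-score threshold for flagging a noisy baseline
% (3 is appended when only 3 elements are given)
% spectralband: the frequency band (in Hz) to use for power filtering
% yz: (optional) precomputed z-scored trajectories laid out like y; when
% given, the z-score filter uses it instead of z-scoring y against the
% baseline frames
% spectrogramproperties: (optional) struct of spectrogram settings with
% fields fs, windowsize, overlap, storedbins, and numfreqpoints
%
% outputs:
% responsive_trajs: a 3-column array with one row per trajectory; columns
% 1-3 indicate whether that trajectory counts as responsive under the
% z-score, power, and anova thresholds respectively
% 
% stimpower: the spectral power in the stimulus window at the
% dilation/contraction frequencies
% baselinepower: the spectral power in the baseline window at the
% dilation/contraction frequencies
% noisybaseline_trajs: an estimate of which trajectories have too much
% noise in their baselines to be considered a valid trajectory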

function [responsive_trajs,stimpower,baselinepower,noisybaseline_trajs] = trajectoryfilter_v1_1(y,baselineframes,stimframes,thresholds,spectralband,yz,spectrogramproperties,noisybaselineframes)
%TRAJECTORYFILTER_V1_1 Flag trajectories whose stimulus window differs from baseline.
%
%   Each row of y is tested by up to three filters. A NaN threshold skips
%   that filter and leaves its column of responsive_trajs as NaN. (The one
%   exception is the z-score filter: when yz is supplied, that filter is
%   always evaluated.)
%     1. z-score: 96th percentile of |z| over stimframes > thresholds(1).
%        z is relative to the mean/std over baselineframes, or is the
%        supplied yz.
%     2. power: mean over in-band frequencies of the spectrogram magnitude
%        ratio (stimulus bin / baseline bin) > thresholds(2).
%     3. anova: one-way ANOVA p-value of baseline vs stimulus samples
%        < thresholds(3).
%   Trajectories with more than 25% NaN samples are set to 0 (false) in
%   every column.
%
%   Inputs (all but y optional; pass [] to use the default):
%     y                     - trajectories, one per row, one timepoint per column
%     baselineframes        - baseline column indices (default 10:45)
%     stimframes            - stimulus column indices (default 46:85)
%     thresholds            - [z, power, anova, z-noise]
%                             (default [4 2.5 .01 3]); if 3 elements are
%                             given, z-noise = 3 is appended
%     spectralband          - [low high] in Hz, both exclusive (default [.5 1.3])
%     yz                    - precomputed z-scored trajectories, same layout as y
%     spectrogramproperties - struct with fields fs, windowsize, overlap,
%                             storedbins (numfreqpoints accepted, no longer
%                             required). Spectrogram bins storedbins(2) and
%                             storedbins(3) are used as baseline and
%                             stimulus respectively.
%     noisybaselineframes   - column indices for the noisy-baseline check
%                             (default 30:50)
%
%   Outputs:
%     responsive_trajs    - size(y,1)-by-3 double of 1/0 (NaN = filter skipped);
%                           columns are the z-score, power, anova filters
%     stimpower           - per-trajectory mean in-band magnitude in the
%                           stimulus bin (NaN if power filter skipped or the
%                           row has >25% NaN)
%     baselinepower       - same, for the baseline bin
%     noisybaseline_trajs - true where the 96th percentile of |z| over
%                           noisybaselineframes exceeds thresholds(4).
%                           Stays NaN when yz is not supplied and
%                           thresholds(1) is NaN.

if nargin < 2 || isempty(baselineframes)
    baselineframes = 10:45;
end
if nargin < 3 || isempty(stimframes)
    stimframes = 46:85;
end
if nargin < 4 || isempty(thresholds)
    thresholds = [4,2.5,.01,3];        % z-score, power, anova, z-score noise thresholds
elseif numel(thresholds) == 3
    % Append the noisy-baseline z threshold. This is only meaningful when
    % the baseline z-score comes from many repeats on a vessel rather than
    % that one vessel. (:).' accepts row or column input.
    thresholds = [thresholds(:).',3];
end
if nargin < 5 || isempty(spectralband)
    spectralband = [.5,1.3];  % in Hz
end
if nargin < 7 || isempty(spectrogramproperties)
    spectrogramproperties.fs = 10;
    spectrogramproperties.windowsize = 23;
    spectrogramproperties.overlap = 0;
    spectrogramproperties.storedbins = 1:5;
    spectrogramproperties.numfreqpoints = 21;
end
if nargin < 8 || isempty(noisybaselineframes)
    % Default assumes a 4.5 s baseline recorded at 10 fps.
    % NOTE(review): 30:50 overlaps the default stimframes (46:85) - confirm intended.
    noisybaselineframes = 30:50;
end

nanthresh = .25;                        % max fraction of NaN samples tolerated per trajectory
ntraj = size(y,1);
responsive_trajs = nan(ntraj,3);
noisybaseline_trajs = nan(ntraj,1);
stimpower = nan(ntraj,1);               % assigned even when the power filter is skipped
baselinepower = nan(ntraj,1);

%% z-score calculation
if nargin < 6 || isempty(yz)
    yz = [];
    if ~isnan(thresholds(1))
        % z-score every timepoint against the row's baseline mean/std
        mu = mean(y(:,baselineframes),2,'omitnan');
        sd = std(y(:,baselineframes),0,2,'omitnan');
        yz = bsxfun(@rdivide,bsxfun(@minus,y,mu),sd);
    end
end
if ~isempty(yz)
    responsive_trajs(:,1) = prctile(abs(yz(:,stimframes)),96,2) > thresholds(1);
    noisybaseline_trajs = prctile(abs(yz(:,noisybaselineframes)),96,2) > thresholds(4);
end

%% power calculation
if ~isnan(thresholds(2))
    fs = spectrogramproperties.fs;
    windowsize = spectrogramproperties.windowsize;
    overlap = spectrogramproperties.overlap;
    storedbins = spectrogramproperties.storedbins;
    basebin = storedbins(2);            % spectrogram time bin treated as baseline
    stimbin = storedbins(3);            % spectrogram time bin treated as stimulus
    pwrratio = nan(ntraj,1);
    for n = 1:ntraj
        cur = y(n,:);
        if sum(isnan(cur))/length(cur) > nanthresh
            continue;                   % leave this row's power outputs NaN
        end
        [s,f] = spectrogram(cur,windowsize,overlap,[],fs);
        s = abs(s);
        inband = f > spectralband(1) & f < spectralband(2);
        sbase = s(inband,basebin);
        sstim = s(inband,stimbin);
        % Averaging per row (rather than filling a preallocated
        % numfreqpoints-row array) works for any band/window combination.
        pwrratio(n) = mean(sstim./sbase,1);
        stimpower(n) = mean(sstim,1);
        baselinepower(n) = mean(sbase,1);
    end
    responsive_trajs(:,2) = pwrratio > thresholds(2);
end

%% anova calculation
if ~isnan(thresholds(3))
    pval = zeros(ntraj,1);
    for n = 1:ntraj
        a = y(n,baselineframes);
        b = y(n,stimframes);
        % nancat (external helper) presumably pads the shorter column with NaN - confirm
        temp = nancat(a',b',2);
        pval(n) = anova1(temp,{'b','s'},'off');
    end
    responsive_trajs(:,3) = pval < thresholds(3);
end

%% exclude trajectories that contain too many NaNs
toomanynans = sum(isnan(y),2)/size(y,2) > nanthresh;
responsive_trajs(toomanynans,:) = false;